home *** CD-ROM | disk | FTP | other *** search
/ Language/OS - Multiplatform Resource Library / LANGUAGE OS.iso / sml_nj / 93src.lha / src / cps / expand.sml < prev    next >
Encoding:
Text File  |  1993-01-27  |  15.5 KB  |  465 lines

  1. (* Copyright 1989 by AT&T Bell Laboratories *)
  2. structure Expand :
  3.     sig val expand : {function: CPS.function,
  4.               bodysize: int,
  5.               unroll: bool,
  6.               afterClosure: bool, do_headers: bool,
  7.               click: string->unit} -> CPS.function
  8.         end =
  9. struct
  10.  
  11.  open Access CPS
  12.  structure CG = System.Control.CG
  13.  
  14.  fun map1 f (a,b) = (f a, b)
  15.  
  16.  fun sum f = let fun h [] = 0 
  17.            | h (a::r) = f a + h r
  18.          in h
  19.          end
  20.  
  21.  fun last0[x]=x | last0(a::b)=last0 b | last0 _ = 0
  22.               
  23.  fun sameName(x,VAR y) = Access.sameName(x,y) 
  24.    | sameName(x,LABEL y) = Access.sameName(x,y) 
  25.    | sameName _ = ()
  26.  
  27.  datatype mode = ALL | NO_UNROLL | UNROLL of int | HEADERS
  28.  
  29. fun expand{function=(fvar,fargs,cexp),unroll,bodysize,click,afterClosure,
  30.        do_headers} =
  31.  let
  32.    val clicked_any = ref false
  33.    val click = fn z => (click z; clicked_any := true)
  34.    val debug = !CG.misc1 (* false *)
  35.    val debugprint = if debug then System.Print.say else fn _ => ()
  36.    val debugflush = if debug then System.Print.flush else fn _ => ()
  37.    fun label v = if afterClosure then LABEL v else VAR v
  38.    datatype info = Fun of {escape: int ref, call: int ref, size: int ref,
  39.                args: lvar list, body: cexp,
  40.                invariant: bool list ref, (* one for each arg *)
  41.                unroll_call: int ref, level: int,
  42.                within: bool ref}
  43.          | Arg of {escape: int ref, savings: int ref,
  44.                record: (int * lvar) list ref}
  45.          | Sel of {savings: int ref}
  46.          | Rec of {escape: int ref, size: int,
  47.                vars: (value * accesspath) list}
  48.          | Real
  49.          | Const
  50.          | Other
  51.  
  52.    exception Expand
  53.    val m : info Intmap.intmap = Intmap.new(128,Expand)
  54.    val note = Intmap.add m
  55.    val get = Intmap.map m
  56.    fun getval(VAR v) = get v
  57.      | getval(LABEL v) = get v
  58.      | getval(REAL _) = Other (*Real*)
  59.      | getval(INT _) = Const
  60.      | getval _ = Other
  61.    fun call(v, args) = case getval v
  62.             of Fun{call,within=ref false,...} => inc call
  63.              | Fun{call,within=ref true,unroll_call,
  64.                    args=vl,invariant,...} => 
  65.                  let fun g(VAR x :: args, x' :: vl, i::inv) =
  66.                        (i andalso x=x') :: g(args,vl,inv)
  67.                    | g( _ :: args, _ :: vl, i::inv) =
  68.                        false :: g(args,vl,inv)
  69.                    | g _ = nil
  70.                   in inc call; inc unroll_call;
  71.                  invariant := g(args,vl,!invariant)
  72.                  end
  73.              | Arg{savings,...} => savings := !savings+1
  74.              | Sel{savings} => savings := !savings+1
  75.              | _ => ()
  76.    fun escape v = case getval v
  77.            of Fun{escape,...} => inc escape
  78.             | Arg{escape,...} => inc escape
  79.             | Rec{escape,...} => inc escape
  80.             | _ => ()
  81.    fun escapeargs v = case getval v
  82.                of Fun{escape,...} => inc escape
  83.             | Arg{escape,savings, ...} =>
  84.                    (inc escape; savings := !savings + 1)
  85.             | Sel{savings} => savings := !savings + 1
  86.             | Rec{escape,...} => inc escape
  87.             | _ => ()
  88.    fun unescapeargs v = case getval v
  89.              of Fun{escape,...} => dec escape
  90.               | Arg{escape,savings, ...} =>
  91.                       (dec escape; savings := !savings - 1)
  92.               | Sel{savings} => savings := !savings - 1
  93.               | Rec{escape,...} => dec escape
  94.               | _ => ()
  95.    fun setsize(f,n) = case get f of Fun{size,...} => (size := n; n)
  96.    fun notearg v = (note (v,Arg{escape=ref 0,savings=ref 0, record=ref []}))
  97.    fun notereal v = note (v,Other(*Real*))
  98.    fun noteother v = note (v,Other)
  99.    fun enter level (f,vl,e) = 
  100.               (note(f,Fun{escape=ref 0, call=ref 0, size=ref 0,
  101.               args=vl, body=e, within=ref false,
  102.               unroll_call = ref 0, 
  103.               invariant = ref(map (fn _ => !CG.invariant) vl),
  104.               level=level});
  105.            app notearg vl)
  106.    fun noterec(w, vl, size) = note (w,Rec{size=size,escape=ref 0,vars=vl})
  107.    fun notesel(i,v,w) = (note (w, Sel{savings=ref 0});
  108.              case getval v
  109.               of Arg{savings,record,...} => (inc savings;
  110.                           record := (i,w)::(!record))
  111.                | _ => ())
  112.  
  113.    fun save(v,k) = case getval v
  114.             of Arg{savings,...} => savings := !savings + k
  115.              | Sel{savings} => savings := !savings + k
  116.              | _ => ()
  117.    fun nsave(v,k) = case getval v
  118.              of Arg{savings,...} => savings := k
  119.               | Sel{savings} => savings := k
  120.               | _ => ()
  121.    fun savesofar v = case getval v 
  122.               of Arg{savings,...} => !savings
  123.                | Sel{savings} => !savings
  124.                | _ => 0
  125.  
  126.    fun within f func arg =
  127.         case get f of Fun{within=w,...} => 
  128.         (w := true; func arg before (w := false))
  129.  
  130.    val rec prim = fn (level,vl,e) =>
  131.        let fun vbl(VAR v) = (case get v of Rec _ => 0 | _ => 1)
  132.          | vbl _ = 0
  133.        val nonconst = sum vbl vl
  134.        val sl = map savesofar vl
  135.        val afterwards = pass1 level e
  136.        val zl = map savesofar vl
  137.        val overhead = length vl + 1
  138.        val potential = overhead
  139.        val savings = case nonconst of
  140.                1 => potential
  141.              | 2 => potential div 4
  142.              | _ => 0
  143.        fun app3 f = let fun loop (a::b,c::d,e::r) = (f(a,c,e); loop(b,d,r))
  144.                   | loop _ = ()
  145.             in loop
  146.             end
  147.        in app3(fn (v,s,z)=> nsave(v,s + savings + (z-s))) (vl,sl,zl);
  148.       overhead+afterwards
  149.        end
  150.  
  151.    and primreal = fn (level,(_,vl,w,e)) =>
  152.        (notereal w;
  153.     app (fn v => save(v,1)) vl;
  154.     2*(length vl + 1) + pass1 level e)
  155.  
  156.    and pass1 : int -> cexp -> int= fn level =>
  157.     fn RECORD(_,vl,w,e) =>
  158.     (app (escape o #1) vl;
  159.      noterec(w,vl,length vl);
  160.      2 + length vl + pass1 level e)
  161.      | SELECT (i,v,w,e) => (notesel(i,v,w); 1 + pass1 level e)
  162.      | OFFSET (i,v,w,e) => (noteother w; 1 + pass1 level e)
  163.      | APP(f,vl) => (call(f,vl); 
  164.              app escapeargs vl; 
  165.              1 + ((length vl + 1) quot 2))
  166.      | FIX(l, e) => 
  167.       (app (enter level) l; 
  168.        sum (fn (f,_,e) => setsize(f, within f (pass1 (level+1)) e)) l + length l + pass1 level e)
  169.      | SWITCH(v,_,el) => let val len = length el
  170.                  val jumps = 4 + len
  171.                  val branches = sum (pass1 level) el
  172.               in save(v, (branches*(len-1)) quot len + jumps);
  173.                  jumps+branches
  174.              end
  175.      | BRANCH(_,vl,c,e1,e2) =>
  176.        let fun vbl(VAR v) = (case get v of Rec _ => 0 | _ => 1)
  177.          | vbl _ = 0
  178.        val nonconst = sum vbl vl
  179.        val sl = map savesofar vl
  180.        val branches = pass1 level e1 + pass1 level e2
  181.        val zl = map savesofar vl
  182.        val overhead = length vl
  183.        val potential = overhead + branches quot 2
  184.        val savings = case nonconst of
  185.                1 => potential
  186.              | 2 => potential div 4
  187.              | _ => 0
  188.        fun app3 f = let fun loop (a::b,c::d,e::r) = (f(a,c,e); loop(b,d,r))
  189.                   | loop _ = ()
  190.             in loop
  191.             end
  192.        in app3(fn (v,s,z)=> nsave(v,s + savings + (z-s) quot 2)) (vl,sl,zl);
  193.       overhead+branches
  194.        end
  195.      | LOOKER(_,vl,w,e) => (noteother w; prim(level,vl,e))
  196.      | SETTER(_,vl,e) => prim(level,vl,e)
  197.      | ARITH(args as (P.floor,_,_,_)) => primreal (level,args)
  198.      | ARITH(args as (P.round,_,_,_)) => primreal (level,args)
  199.      | ARITH(args as (P.fadd,_,_,_)) => primreal (level,args)
  200.      | ARITH(args as (P.fdiv,_,_,_)) => primreal (level,args)
  201.      | ARITH(args as (P.fmul,_,_,_)) => primreal (level,args)
  202.      | ARITH(args as (P.fsub,_,_,_)) => primreal (level,args)
  203.      | ARITH(_,vl,w,e) => (noteother w; prim(level,vl,e))
  204.      | PURE(P.fnegd,[v],w,e) => (notereal w; save(v,1); 4+(pass1 level e))
  205.      | PURE(P.fabsd,[v],w,e) => (notereal w; save(v,1); 4+(pass1 level e))
  206.      | PURE(P.real,vl,w,e) => (notereal w; prim(level,vl,e))
  207.      | PURE(_,vl,w,e) => (noteother w; prim(level,vl,e))
  208.  
  209.    fun substitute(args,wl,e,level,alpha) =
  210.     let exception Alpha
  211.     val vm : value Intmap.intmap = Intmap.new(16, Alpha)
  212.     fun use(v0 as VAR v) = (Intmap.map vm v handle Alpha => v0)
  213.       | use(v0 as LABEL v) = (Intmap.map vm v handle Alpha => v0)
  214.       | use x = x
  215.     fun def v = if alpha
  216.                  then let val w = dupLvar v 
  217.                in Intmap.add vm (v, VAR w); w
  218.               end
  219.              else v 
  220.     fun defl v = if alpha
  221.                  then let val w = dupLvar v 
  222.                in Intmap.add vm (v, label w);
  223.                    w
  224.               end
  225.              else v
  226.     fun bind(a::args,w::wl) = 
  227.            (sameName(w,a); Intmap.add vm (w,a); bind(args,wl))
  228.       | bind _ = ()
  229.     val rec g =
  230.        fn RECORD(k,vl,w,ce) => RECORD(k,map (map1 use) vl, def w, g ce)
  231.     | SELECT(i,v,w,ce) => SELECT(i, use v, def w, g ce)
  232.     | OFFSET(i,v,w,ce) => OFFSET(i, use v, def w, g ce)
  233.     | APP(v,vl) => APP(use v, map use vl)
  234.     | FIX(l,ce) => 
  235.       let fun h1(f,vl,e) = (f,defl f, vl, e)
  236.           fun h2(f,f',vl,e) =
  237.           let val vl' = map def vl
  238.               val e'= g e
  239.           in (f', vl', e')
  240.           end
  241.        in FIX(map h2(map h1 l), g ce)
  242.       end
  243.     | SWITCH(v,c,l) => SWITCH(use v, def c, map g l)
  244.     | LOOKER(i,vl,w,e) => LOOKER(i, map use vl, def w, g e)
  245.     | ARITH(i,vl,w,e) => ARITH(i, map use vl, def w, g e)
  246.     | PURE(i,vl,w,e) => PURE(i, map use vl, def w, g e)
  247.     | SETTER(i,vl,e) => SETTER(i, map use vl, g e)
  248.     | BRANCH(i,vl,c,e1,e2) => BRANCH(i, map use vl, def c, g e1, g e2)
  249.     val cexp = (bind(args,wl); g e)
  250.     in  (*debugprint("\nSize=" ^ makestring(pass1 level cexp)); debugprint " "; *)
  251.     if alpha then pass1 level cexp else 0;
  252.     cexp
  253.     end
  254.  
  255.    fun whatsave(acc, size, (v:value)::vl, a::al) =
  256.        if acc>=size
  257.        then acc
  258.        else
  259.        (case get a of
  260.       Arg{escape=ref esc,savings=ref save,record=ref rl} =>
  261.       let val (this, nvl: value list, nal) =
  262.            case getval v
  263.         of Fun{escape=ref 1,...} =>
  264.             (if esc>0 then save else 6+save,vl,al)
  265.          | Fun _ => (save,vl,al)
  266.          | Rec{escape=ref ex,vars,size} =>
  267.               let exception Chase
  268.               fun chasepath(v,OFFp 0) = v
  269.                 | chasepath(v, SELp(i,p)) =
  270.                    (case getval v
  271.                  of Rec{vars,...} =>
  272.                     chasepath(chasepath(nth(vars,i)),p)
  273.                   | _ => raise Chase)
  274.                 | chasepath _ = raise Chase
  275.               fun loop([],nvl,nal) = 
  276.                   (if ex>1 orelse esc>0
  277.                    then save
  278.                    else save+size+2,nvl,nal)
  279.                 | loop((i,w)::rl,nvl,nal) =
  280.                    loop(rl,
  281.                   chasepath(nth(vars,i))::nvl,
  282.                     w::nal)
  283.                in loop(rl,vl,al)
  284.               handle Chase => (0,vl,al)
  285.                    | Nth => (0,vl,al)
  286.               end 
  287.         (* | Real => (save,vl,al)*)
  288.          | Const => (save,vl,al)
  289.          | _ => (0,vl,al)
  290.       in whatsave(acc+this - (acc*this) quot size, size, nvl,nal)
  291.       end
  292.     | Sel{savings=ref save} =>
  293.       let val this =
  294.           case v
  295.            of VAR v' => (case get v' of
  296.                   Fun _ => save
  297.                 | Rec _ => save
  298.                 | _ => 0)
  299.         | _ => save
  300.       in whatsave(acc + this - (acc*this) quot size,size, vl,al)
  301.       end)
  302.      | whatsave(acc,size,_,_) = acc
  303.  
  304.    fun beta(n, (* how many expansions we are within *)
  305.         d, (* path length from start of current function *)
  306.         u,  (* unroll-info *)
  307.         e (* expression to traverse *)
  308.         ) = case e
  309.     of RECORD(k,vl,w,ce) => RECORD(k,vl, w, beta(n,d+2+length vl, u, ce))
  310.      | SELECT(i,v,w,ce) => SELECT(i, v, w, beta(n,d+1, u, ce))
  311.      | OFFSET(i,v,w,ce) => OFFSET(i, v, w, beta(n,d+1, u, ce))
  312.      | APP(v,vl) => 
  313.      (case getval v
  314.        of info as Fun{args,body,...} =>
  315.            if should_expand(n,d,u,e,info)
  316.            then let val new = beta(n+1, d+1, u, 
  317.                       substitute(vl,args,body,
  318.                          case u of UNROLL lev => lev
  319.                                  | _ => 0,
  320.                          true))
  321.             in click "^";
  322.                case v of VAR vv => debugprint(makestring vv) | _ => ();
  323.                app unescapeargs vl;
  324.                new
  325.             end
  326.             else e
  327.         | _ => e)
  328.      | FIX(l,ce) => FIX(if n<1 then map (fundef(n,d,u)) l else l, 
  329.             beta(n,d+length l, u,ce))
  330.      | SWITCH(v,c,l) => SWITCH(v, c, map (fn e => beta(n,d+2,u,e)) l)
  331.      | LOOKER(i,vl,w,e) => LOOKER(i, vl, w, beta(n,d+2,u,e))
  332.      | ARITH(i,vl,w,e) => ARITH(i, vl, w, beta(n,d+2,u,e))
  333.      | PURE(i,vl,w,e) => PURE(i, vl, w, beta(n,d+2,u,e))
  334.      | SETTER(i,vl,e) => SETTER(i, vl, beta(n,d+2,u,e))
  335.      | BRANCH(i,vl,c,e1,e2) => BRANCH(i, vl, c,beta(n,d+2,u,e1), 
  336.                       beta(n,d+2,u,e2))
  337.  
  338.     and should_expand(n,d,HEADERS,e,_) = false
  339.       | should_expand(n,d,u,e as APP(v,vl), 
  340.               Fun{escape,call,unroll_call,size=ref size,args,body,
  341.               level,within=ref within,...}) =
  342.       let val stupidloop =  (* prevent infinite loops  at compile time *)
  343.         case (v,body) 
  344.          of (VAR vv, APP(VAR v',_)) => vv=v' 
  345.           | (LABEL vv, APP(LABEL v',_)) => vv=v' 
  346.           | _ => false
  347.     val calls = case u of UNROLL _ => !unroll_call | _ => !call
  348.     val small_fun_size = case u of UNROLL _ => 0 | _ => 50
  349.     val savings = whatsave(0,size,vl,args)
  350.     val predicted = 
  351.         let val real_increase = size-savings-(1+length vl)
  352.         in  real_increase * calls - 
  353.         (* don't subtract off the original body if
  354.            the original body is huge (because we might
  355.            have guessed wrong and the consequences are
  356.            too nasty for big functions); or if we're
  357.            in unroll mode *)
  358.         (if size < small_fun_size then size else 0)
  359.         end
  360.     val depth = 2 and max = 2
  361.     val increase = (bodysize*(depth - n)) quot depth
  362.  
  363.     in if false andalso debug
  364.       then (CPSprint.show System.Print.say e;
  365.         debugprint(makestring predicted); debugprint "   "; 
  366.         debugprint(makestring increase);
  367.         debugprint"   "; debugprint (makestring n); debugprint "\n")
  368.      else ();
  369.  
  370.        not stupidloop
  371.        andalso case u
  372.         of UNROLL lev => 
  373.          (* Unroll if: the loop body doesn't make function
  374.             calls orelse "unroll_recur" is turned on; andalso 
  375.             we are within the definition of the function; 
  376.             andalso it looks like things won't grow too much.
  377.           *)
  378.            (!CG.unroll_recur orelse level >= lev)
  379.            andalso n <= max
  380.            andalso within andalso predicted <= increase
  381.          | NO_UNROLL =>
  382.            !unroll_call = 0 andalso
  383.            not within andalso n <= max andalso
  384.            (predicted <= increase  
  385.              orelse (!escape=0 andalso calls = 1))
  386.          | HEADERS => false
  387.          | ALL => n <= max andalso
  388.            (predicted <= increase  
  389.              orelse (!escape=0 andalso calls = 1))
  390.   end
  391.  
  392.    and fundef (n,d,HEADERS) (f,vl,e) = 
  393.     let val Fun{within,escape=ref escape,call,unroll_call,
  394.         invariant=ref inv,...} = get f
  395.  
  396.      in within := true;
  397.     (if  escape = 0 andalso !unroll_call > 0
  398.          andalso (!call - !unroll_call > 1 orelse exists (fn t=>t) inv)
  399.      then let val f'::vl' = map dupLvar (f::vl)
  400.           val within' = ref true
  401.           fun drop(false::r,a::s) = a::drop(r,s)
  402.             | drop(true::r,_::s) = drop(r,s)
  403.             | drop _ = nil
  404.           val e' =substitute(label f' :: map VAR (drop(inv,vl')),
  405.                      f :: drop(inv,vl),
  406.                      beta(n,0,HEADERS,e),
  407.                      0, false) 
  408.            in click "!"; debugprint(makestring f);
  409.            
  410.           enter 0 (f',vl',e');
  411.           (f,vl,FIX([(f',vl',e')], APP(label f', map VAR vl)))
  412.          end
  413.     else (f,vl,beta(n,0,HEADERS,e)))
  414.  
  415.         before within := false
  416.    end
  417.  
  418.     | fundef (n,d,u) (f,vl,e) = 
  419.     let val Fun{level,within,escape=ref escape,...} = get f
  420.  
  421.     val u' = case u of UNROLL _ => UNROLL level | _ => u
  422.  
  423.     in if (case e
  424.         of FIX([(g,[b,k],APP _)], APP(VAR c,[VAR g'])) =>
  425.                c=last0 vl andalso g=g'
  426.          | APP _ => escape > 0
  427.          | _ => false)
  428.         then (f,vl,e) (* Don't contract eta-splits *) 
  429.     else (within := true; (f,vl,beta(n,0,u',e)) before within := false)
  430.    end
  431.  
  432.   in notearg fvar; app notearg fargs;
  433. (*     if !CG.printit then CPSprint.show System.Print.say cexp
  434.      else ();
  435. *)     debugprint("\nExpand   "); debugprint(makestring(pass1 0 cexp));
  436.      if unroll
  437.      then let val _ = (debugprint("  (unroll)\n"); debugflush());
  438.           val e' = beta(0,0,UNROLL 0,cexp)
  439.           in if !clicked_any 
  440.              then expand{function=(fvar, fargs, e'),
  441.                  bodysize=bodysize,click=click,unroll=unroll,
  442.                  afterClosure=afterClosure,
  443.                  do_headers=do_headers}
  444.          else ((*debugprint("\nExpand\n"); 
  445.                  debugflush();
  446.                (fvar, fargs, beta(0,0,ALL,cexp)) *)
  447.                (fvar, fargs, e'))
  448.           end
  449.      else if !CG.unroll
  450.      then let val _ = (debugprint(" (headers)\n"); debugflush())
  451.           val e' = if do_headers then beta(0,0,HEADERS,cexp) else cexp
  452.            in if !clicked_any
  453.           then expand{function=(fvar,fargs,e'),
  454.                   bodysize=bodysize,click=click,unroll=unroll,
  455.                   afterClosure=afterClosure, do_headers=false}
  456.           else (debugprint("\n  (non-unroll)\n"); debugflush();
  457.             (fvar, fargs, beta(0,0,NO_UNROLL,e')))
  458.           end
  459.      else (debugprint("\n"); debugflush();
  460.            (fvar, fargs, beta(0,0,ALL,cexp)))
  461.      
  462.  end
  463.  
  464. end
  465.